{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "234afff3",
   "metadata": {},
   "source": [
    "# Geneformer Fine-Tuning for Cell Type Annotation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1cbe6178-ea4d-478a-80a8-65ffaa4c1820",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# make only the GPU indices listed in GPU_NUMBER visible to CUDA; set NCCL debug output level to INFO\n",
    "GPU_NUMBER = [0]\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join(map(str, GPU_NUMBER))\n",
    "os.environ[\"NCCL_DEBUG\"] = \"INFO\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a9885d9f-00ac-4c84-b6a3-b7b648a90f0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "from collections import Counter\n",
    "import datetime\n",
    "import pickle\n",
    "import subprocess\n",
    "import seaborn as sns; sns.set()\n",
    "from datasets import load_from_disk\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "from transformers import BertForSequenceClassification\n",
    "from transformers import Trainer\n",
    "from transformers.training_args import TrainingArguments\n",
    "\n",
    "from geneformer import DataCollatorForCellClassification"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68bd3b98-5409-4105-b7af-f1ff64ea6a72",
   "metadata": {},
   "source": [
    "## Prepare training and evaluation datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "5735f1b7-7595-4a02-be17-2c5b970ad81a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# cell type dataset covering all tissues; per-organ subsets are built in the next cell\n",
    "train_dataset = load_from_disk(\n",
    "    \"/path/to/cell_type_train_data.dataset\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4297a02-4c4c-434c-ae55-3387a0b239b5",
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "def prepare_organ_datasets(dataset, organ_ids, min_celltype_fraction=0.005,\n",
    "                           train_fraction=0.8, seed=42, n_proc=16):\n",
    "    \"\"\"Build labeled train/eval splits for one organ (or a pooled group of organs).\n",
    "\n",
    "    Keeps rows whose \"organ_major\" is in `organ_ids`, drops rare cell types,\n",
    "    shuffles, renames \"cell_type\" -> \"label\", drops \"organ_major\", encodes the\n",
    "    labels as integer ids, splits into train/eval, and removes eval rows whose\n",
    "    label does not occur in the train split.\n",
    "\n",
    "    Args:\n",
    "        dataset: dataset with \"organ_major\" and \"cell_type\" columns\n",
    "            (as loaded with `load_from_disk`).\n",
    "        organ_ids: \"organ_major\" values to pool together.\n",
    "        min_celltype_fraction: cell types whose count is not strictly greater\n",
    "            than this fraction of the organ's cells are dropped (0.005 = 0.5%).\n",
    "        train_fraction: fraction of the shuffled cells used for training.\n",
    "        seed: seed for the shuffle.\n",
    "        n_proc: number of processes passed to `filter`/`map`.\n",
    "\n",
    "    Returns:\n",
    "        (train_split, eval_split, label_dict), where label_dict maps each\n",
    "        retained cell type name to its integer label id.\n",
    "    \"\"\"\n",
    "    # filter dataset for the given organ(s)\n",
    "    def if_organ(example):\n",
    "        return example[\"organ_major\"] in organ_ids\n",
    "    organ_subset = dataset.filter(if_organ, num_proc=n_proc)\n",
    "\n",
    "    # per scDeepsort published method, drop rare cell types (the strict \">\" means a\n",
    "    # cell type at exactly the threshold is dropped as well)\n",
    "    celltype_counter = Counter(organ_subset[\"cell_type\"])\n",
    "    total_cells = sum(celltype_counter.values())\n",
    "    cells_to_keep = [k for k, v in celltype_counter.items() if v > (min_celltype_fraction * total_cells)]\n",
    "    def if_not_rare_celltype(example):\n",
    "        return example[\"cell_type\"] in cells_to_keep\n",
    "    common_subset = organ_subset.filter(if_not_rare_celltype, num_proc=n_proc)\n",
    "\n",
    "    # shuffle cells, rename cell_type -> label, drop organ_major\n",
    "    shuffled = common_subset.shuffle(seed=seed)\n",
    "    shuffled = shuffled.rename_column(\"cell_type\", \"label\")\n",
    "    shuffled = shuffled.remove_columns(\"organ_major\")\n",
    "\n",
    "    # cell type name -> integer id, numbered in order of first appearance after shuffling\n",
    "    target_names = list(Counter(shuffled[\"label\"]).keys())\n",
    "    label_dict = {name: idx for idx, name in enumerate(target_names)}\n",
    "\n",
    "    # change labels to numerical ids\n",
    "    def classes_to_ids(example):\n",
    "        example[\"label\"] = label_dict[example[\"label\"]]\n",
    "        return example\n",
    "    labeled = shuffled.map(classes_to_ids, num_proc=n_proc)\n",
    "\n",
    "    # contiguous train/eval split of the (already shuffled) cells\n",
    "    n_train = round(len(labeled) * train_fraction)\n",
    "    train_split = labeled.select(range(0, n_train))\n",
    "    eval_split = labeled.select(range(n_train, len(labeled)))\n",
    "\n",
    "    # evaluate only on cell types that occur in the training split\n",
    "    trained_labels = list(Counter(train_split[\"label\"]).keys())\n",
    "    def if_trained_label(example):\n",
    "        return example[\"label\"] in trained_labels\n",
    "    eval_subset = eval_split.filter(if_trained_label, num_proc=n_proc)\n",
    "\n",
    "    return train_split, eval_subset, label_dict\n",
    "\n",
    "\n",
    "dataset_list = []\n",
    "evalset_list = []\n",
    "organ_list = []\n",
    "target_dict_list = []\n",
    "\n",
    "for organ in Counter(train_dataset[\"organ_major\"]).keys():\n",
    "    # immune and bone marrow are fine-tuned together under the \"immune\" key,\n",
    "    # so bone_marrow is skipped as a standalone organ\n",
    "    if organ == \"bone_marrow\":\n",
    "        continue\n",
    "    organ_ids = [\"immune\", \"bone_marrow\"] if organ == \"immune\" else [organ]\n",
    "\n",
    "    print(organ)\n",
    "    train_split, eval_split, label_dict = prepare_organ_datasets(train_dataset, organ_ids)\n",
    "\n",
    "    # append only once the organ was processed successfully so the four lists stay aligned\n",
    "    organ_list.append(organ)\n",
    "    dataset_list.append(train_split)\n",
    "    evalset_list.append(eval_split)\n",
    "    target_dict_list.append(label_dict)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "83e20521-597a-4c54-897b-c4d42ea622c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# look up each organ's train split, label dictionary and eval split by organ name\n",
    "trainset_dict = {name: split for name, split in zip(organ_list, dataset_list)}\n",
    "traintargetdict_dict = {name: labels for name, labels in zip(organ_list, target_dict_list)}\n",
    "evalset_dict = {name: split for name, split in zip(organ_list, evalset_list)}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10eb110d-ba43-4efc-bc43-1815d6912647",
   "metadata": {},
   "source": [
    "## Fine-Tune With Cell Classification Learning Objective and Quantify Predictive Performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "cd7b1cfb-f5cb-460e-ae77-769522ece054",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(pred):\n",
    "    \"\"\"Compute evaluation metrics for the Hugging Face Trainer.\n",
    "\n",
    "    Args:\n",
    "        pred: prediction object exposing `label_ids` and `predictions`\n",
    "            (presumably per-class scores of shape (n_cells, n_classes) -- TODO confirm).\n",
    "\n",
    "    Returns:\n",
    "        dict with accuracy, macro-averaged F1 and support-weighted F1.\n",
    "    \"\"\"\n",
    "    labels = pred.label_ids\n",
    "    # predicted class = index of the highest score along the last axis\n",
    "    preds = pred.predictions.argmax(-1)\n",
    "    # calculate accuracy, macro f1 and weighted f1 using sklearn's functions;\n",
    "    # macro weights every cell type equally, weighted averages by class support\n",
    "    acc = accuracy_score(labels, preds)\n",
    "    macro_f1 = f1_score(labels, preds, average='macro')\n",
    "    weighted_f1 = f1_score(labels, preds, average='weighted')\n",
    "    return {\n",
    "      'accuracy': acc,\n",
    "      'macro_f1': macro_f1,\n",
    "      'weighted_f1': weighted_f1\n",
    "    }"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "beaab7a4-cc13-4e8f-b137-ed18ff7b633c",
   "metadata": {},
   "source": [
    "**Note:** as with most deep learning models, we **highly** recommend tuning the training hyperparameters for every fine-tuning application, as this can significantly improve model performance. Example hyperparameters are defined below; see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d24e1ab7-0131-44bd-b458-1ce5ba31853e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# set model parameters\n",
    "# max input size\n",
    "max_input_size = 2 ** 11  # 2048\n",
    "\n",
    "# set training hyperparameters\n",
    "# max learning rate\n",
    "max_lr = 5e-5\n",
    "# how many pretrained layers to freeze (0 = none frozen)\n",
    "freeze_layers = 0\n",
    "# number gpus: derived from GPU_NUMBER (the devices written to CUDA_VISIBLE_DEVICES)\n",
    "# so the count cannot drift from the devices actually exposed\n",
    "num_gpus = len(GPU_NUMBER)\n",
    "# number cpu cores\n",
    "num_proc = 16\n",
    "# batch size for training and eval\n",
    "geneformer_batch_size = 12\n",
    "# learning schedule\n",
    "lr_schedule_fn = \"linear\"\n",
    "# warmup steps\n",
    "warmup_steps = 500\n",
    "# number of epochs\n",
    "epochs = 10\n",
    "# optimizer\n",
    "optimizer = \"adamw\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "05164c24-5fbf-4372-b26c-a43f3777a88d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "spleen\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='10280' max='10280' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [10280/10280 13:33, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Macro F1</th>\n",
       "      <th>Weighted F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.087000</td>\n",
       "      <td>0.068067</td>\n",
       "      <td>0.985404</td>\n",
       "      <td>0.956839</td>\n",
       "      <td>0.985483</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.044400</td>\n",
       "      <td>0.075289</td>\n",
       "      <td>0.985079</td>\n",
       "      <td>0.955069</td>\n",
       "      <td>0.984898</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.066700</td>\n",
       "      <td>0.078703</td>\n",
       "      <td>0.983782</td>\n",
       "      <td>0.953240</td>\n",
       "      <td>0.983959</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.037400</td>\n",
       "      <td>0.057132</td>\n",
       "      <td>0.989945</td>\n",
       "      <td>0.970619</td>\n",
       "      <td>0.989883</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.025000</td>\n",
       "      <td>0.061644</td>\n",
       "      <td>0.988323</td>\n",
       "      <td>0.961126</td>\n",
       "      <td>0.988211</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.022400</td>\n",
       "      <td>0.065323</td>\n",
       "      <td>0.989296</td>\n",
       "      <td>0.969737</td>\n",
       "      <td>0.989362</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.018600</td>\n",
       "      <td>0.063710</td>\n",
       "      <td>0.989620</td>\n",
       "      <td>0.969436</td>\n",
       "      <td>0.989579</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.039800</td>\n",
       "      <td>0.065919</td>\n",
       "      <td>0.989945</td>\n",
       "      <td>0.968065</td>\n",
       "      <td>0.989802</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.030200</td>\n",
       "      <td>0.061359</td>\n",
       "      <td>0.990269</td>\n",
       "      <td>0.971700</td>\n",
       "      <td>0.990314</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.013400</td>\n",
       "      <td>0.059181</td>\n",
       "      <td>0.991567</td>\n",
       "      <td>0.974599</td>\n",
       "      <td>0.991552</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='257' max='257' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [257/257 00:07]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "kidney\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='29340' max='29340' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [29340/29340 45:43, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Macro F1</th>\n",
       "      <th>Weighted F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.326900</td>\n",
       "      <td>0.299193</td>\n",
       "      <td>0.912500</td>\n",
       "      <td>0.823067</td>\n",
       "      <td>0.909627</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.224200</td>\n",
       "      <td>0.239580</td>\n",
       "      <td>0.926477</td>\n",
       "      <td>0.850237</td>\n",
       "      <td>0.923902</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.221600</td>\n",
       "      <td>0.242810</td>\n",
       "      <td>0.930227</td>\n",
       "      <td>0.878553</td>\n",
       "      <td>0.930349</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.166100</td>\n",
       "      <td>0.264178</td>\n",
       "      <td>0.933409</td>\n",
       "      <td>0.884759</td>\n",
       "      <td>0.933031</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.144100</td>\n",
       "      <td>0.279282</td>\n",
       "      <td>0.935000</td>\n",
       "      <td>0.887659</td>\n",
       "      <td>0.934987</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.112800</td>\n",
       "      <td>0.307647</td>\n",
       "      <td>0.935909</td>\n",
       "      <td>0.889239</td>\n",
       "      <td>0.935365</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.084600</td>\n",
       "      <td>0.326399</td>\n",
       "      <td>0.932841</td>\n",
       "      <td>0.892447</td>\n",
       "      <td>0.933191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.068300</td>\n",
       "      <td>0.332626</td>\n",
       "      <td>0.936591</td>\n",
       "      <td>0.891629</td>\n",
       "      <td>0.936354</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.065500</td>\n",
       "      <td>0.348174</td>\n",
       "      <td>0.935227</td>\n",
       "      <td>0.889484</td>\n",
       "      <td>0.935040</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.046100</td>\n",
       "      <td>0.355350</td>\n",
       "      <td>0.935000</td>\n",
       "      <td>0.894578</td>\n",
       "      <td>0.934971</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='734' max='734' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [734/734 00:27]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "lung\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='21750' max='21750' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [21750/21750 30:32, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Macro F1</th>\n",
       "      <th>Weighted F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.337600</td>\n",
       "      <td>0.341523</td>\n",
       "      <td>0.906360</td>\n",
       "      <td>0.759979</td>\n",
       "      <td>0.899310</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.211900</td>\n",
       "      <td>0.258954</td>\n",
       "      <td>0.928429</td>\n",
       "      <td>0.835534</td>\n",
       "      <td>0.925903</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.208600</td>\n",
       "      <td>0.282081</td>\n",
       "      <td>0.930421</td>\n",
       "      <td>0.842786</td>\n",
       "      <td>0.928013</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.144400</td>\n",
       "      <td>0.253047</td>\n",
       "      <td>0.935479</td>\n",
       "      <td>0.871712</td>\n",
       "      <td>0.935234</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.109200</td>\n",
       "      <td>0.268833</td>\n",
       "      <td>0.939464</td>\n",
       "      <td>0.876173</td>\n",
       "      <td>0.938870</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.132700</td>\n",
       "      <td>0.282697</td>\n",
       "      <td>0.940536</td>\n",
       "      <td>0.883271</td>\n",
       "      <td>0.940191</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.081800</td>\n",
       "      <td>0.295864</td>\n",
       "      <td>0.940843</td>\n",
       "      <td>0.884201</td>\n",
       "      <td>0.940170</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.035900</td>\n",
       "      <td>0.306600</td>\n",
       "      <td>0.941916</td>\n",
       "      <td>0.884777</td>\n",
       "      <td>0.941578</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.050800</td>\n",
       "      <td>0.311677</td>\n",
       "      <td>0.940536</td>\n",
       "      <td>0.883437</td>\n",
       "      <td>0.940294</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.035800</td>\n",
       "      <td>0.315360</td>\n",
       "      <td>0.940843</td>\n",
       "      <td>0.883551</td>\n",
       "      <td>0.940612</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='544' max='544' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [544/544 00:19]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "brain\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='8880' max='8880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [8880/8880 11:14, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Macro F1</th>\n",
       "      <th>Weighted F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.163100</td>\n",
       "      <td>0.156640</td>\n",
       "      <td>0.970345</td>\n",
       "      <td>0.736455</td>\n",
       "      <td>0.960714</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.149800</td>\n",
       "      <td>0.134897</td>\n",
       "      <td>0.968844</td>\n",
       "      <td>0.747114</td>\n",
       "      <td>0.960726</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.105600</td>\n",
       "      <td>0.115354</td>\n",
       "      <td>0.972222</td>\n",
       "      <td>0.775271</td>\n",
       "      <td>0.964932</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.086900</td>\n",
       "      <td>0.207918</td>\n",
       "      <td>0.968844</td>\n",
       "      <td>0.707927</td>\n",
       "      <td>0.958257</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.056400</td>\n",
       "      <td>0.106548</td>\n",
       "      <td>0.974099</td>\n",
       "      <td>0.839838</td>\n",
       "      <td>0.971611</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.037600</td>\n",
       "      <td>0.117437</td>\n",
       "      <td>0.978228</td>\n",
       "      <td>0.856578</td>\n",
       "      <td>0.975665</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.030500</td>\n",
       "      <td>0.127885</td>\n",
       "      <td>0.974474</td>\n",
       "      <td>0.856296</td>\n",
       "      <td>0.973531</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.019300</td>\n",
       "      <td>0.143203</td>\n",
       "      <td>0.977853</td>\n",
       "      <td>0.859362</td>\n",
       "      <td>0.975776</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.007400</td>\n",
       "      <td>0.153758</td>\n",
       "      <td>0.972598</td>\n",
       "      <td>0.852835</td>\n",
       "      <td>0.972314</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.017200</td>\n",
       "      <td>0.153911</td>\n",
       "      <td>0.975976</td>\n",
       "      <td>0.858196</td>\n",
       "      <td>0.974498</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='222' max='222' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [222/222 00:04]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "placenta\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='6180' max='6180' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [6180/6180 10:28, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Macro F1</th>\n",
       "      <th>Weighted F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.128700</td>\n",
       "      <td>0.125175</td>\n",
       "      <td>0.960626</td>\n",
       "      <td>0.935752</td>\n",
       "      <td>0.959463</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.064000</td>\n",
       "      <td>0.215607</td>\n",
       "      <td>0.951456</td>\n",
       "      <td>0.920579</td>\n",
       "      <td>0.949828</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.051300</td>\n",
       "      <td>0.203044</td>\n",
       "      <td>0.961165</td>\n",
       "      <td>0.934195</td>\n",
       "      <td>0.959470</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.045300</td>\n",
       "      <td>0.115701</td>\n",
       "      <td>0.978964</td>\n",
       "      <td>0.966387</td>\n",
       "      <td>0.978788</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.048200</td>\n",
       "      <td>0.149484</td>\n",
       "      <td>0.973571</td>\n",
       "      <td>0.958927</td>\n",
       "      <td>0.973305</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.040900</td>\n",
       "      <td>0.134339</td>\n",
       "      <td>0.978964</td>\n",
       "      <td>0.967466</td>\n",
       "      <td>0.978899</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.001600</td>\n",
       "      <td>0.159900</td>\n",
       "      <td>0.978425</td>\n",
       "      <td>0.966713</td>\n",
       "      <td>0.978211</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.002400</td>\n",
       "      <td>0.125351</td>\n",
       "      <td>0.979504</td>\n",
       "      <td>0.968064</td>\n",
       "      <td>0.979428</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.009400</td>\n",
       "      <td>0.120132</td>\n",
       "      <td>0.980583</td>\n",
       "      <td>0.969631</td>\n",
       "      <td>0.980506</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.001500</td>\n",
       "      <td>0.137864</td>\n",
       "      <td>0.978964</td>\n",
       "      <td>0.967180</td>\n",
       "      <td>0.978825</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='155' max='155' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [155/155 00:05]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "immune\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='17140' max='17140' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [17140/17140 22:02, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Macro F1</th>\n",
       "      <th>Weighted F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.288900</td>\n",
       "      <td>0.231582</td>\n",
       "      <td>0.936770</td>\n",
       "      <td>0.868405</td>\n",
       "      <td>0.934816</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.203200</td>\n",
       "      <td>0.206292</td>\n",
       "      <td>0.937354</td>\n",
       "      <td>0.888661</td>\n",
       "      <td>0.939555</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.183500</td>\n",
       "      <td>0.195811</td>\n",
       "      <td>0.944942</td>\n",
       "      <td>0.891149</td>\n",
       "      <td>0.944008</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.151000</td>\n",
       "      <td>0.219581</td>\n",
       "      <td>0.947665</td>\n",
       "      <td>0.906578</td>\n",
       "      <td>0.947093</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.090000</td>\n",
       "      <td>0.247120</td>\n",
       "      <td>0.946693</td>\n",
       "      <td>0.898812</td>\n",
       "      <td>0.945808</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.060400</td>\n",
       "      <td>0.249662</td>\n",
       "      <td>0.948444</td>\n",
       "      <td>0.905014</td>\n",
       "      <td>0.947975</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.071300</td>\n",
       "      <td>0.272767</td>\n",
       "      <td>0.949416</td>\n",
       "      <td>0.911514</td>\n",
       "      <td>0.949748</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.052600</td>\n",
       "      <td>0.305051</td>\n",
       "      <td>0.945331</td>\n",
       "      <td>0.902348</td>\n",
       "      <td>0.944987</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.026900</td>\n",
       "      <td>0.294135</td>\n",
       "      <td>0.948638</td>\n",
       "      <td>0.904058</td>\n",
       "      <td>0.948296</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.034500</td>\n",
       "      <td>0.292029</td>\n",
       "      <td>0.950195</td>\n",
       "      <td>0.908547</td>\n",
       "      <td>0.949753</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='429' max='429' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [429/429 00:13]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "large_intestine\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='33070' max='33070' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [33070/33070 43:02, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Macro F1</th>\n",
       "      <th>Weighted F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.306200</td>\n",
       "      <td>0.312431</td>\n",
       "      <td>0.908266</td>\n",
       "      <td>0.786242</td>\n",
       "      <td>0.900768</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.223900</td>\n",
       "      <td>0.248096</td>\n",
       "      <td>0.925101</td>\n",
       "      <td>0.841251</td>\n",
       "      <td>0.920987</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.173600</td>\n",
       "      <td>0.259997</td>\n",
       "      <td>0.925907</td>\n",
       "      <td>0.850348</td>\n",
       "      <td>0.926290</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.162900</td>\n",
       "      <td>0.282306</td>\n",
       "      <td>0.925000</td>\n",
       "      <td>0.873669</td>\n",
       "      <td>0.925531</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.143400</td>\n",
       "      <td>0.254494</td>\n",
       "      <td>0.937903</td>\n",
       "      <td>0.876749</td>\n",
       "      <td>0.937836</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.104500</td>\n",
       "      <td>0.289942</td>\n",
       "      <td>0.934677</td>\n",
       "      <td>0.875333</td>\n",
       "      <td>0.934339</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.080300</td>\n",
       "      <td>0.313914</td>\n",
       "      <td>0.935484</td>\n",
       "      <td>0.877271</td>\n",
       "      <td>0.934986</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.063500</td>\n",
       "      <td>0.339868</td>\n",
       "      <td>0.936290</td>\n",
       "      <td>0.882267</td>\n",
       "      <td>0.936187</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.042500</td>\n",
       "      <td>0.345784</td>\n",
       "      <td>0.938911</td>\n",
       "      <td>0.882963</td>\n",
       "      <td>0.938682</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.038900</td>\n",
       "      <td>0.352199</td>\n",
       "      <td>0.939516</td>\n",
       "      <td>0.885509</td>\n",
       "      <td>0.939497</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='827' max='827' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [827/827 00:26]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "pancreas\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='18280' max='18280' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [18280/18280 23:32, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Macro F1</th>\n",
       "      <th>Weighted F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.340100</td>\n",
       "      <td>0.343200</td>\n",
       "      <td>0.896244</td>\n",
       "      <td>0.655661</td>\n",
       "      <td>0.879469</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.178300</td>\n",
       "      <td>0.224033</td>\n",
       "      <td>0.930890</td>\n",
       "      <td>0.859772</td>\n",
       "      <td>0.925342</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.154200</td>\n",
       "      <td>0.208034</td>\n",
       "      <td>0.941284</td>\n",
       "      <td>0.887012</td>\n",
       "      <td>0.939485</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.121200</td>\n",
       "      <td>0.216660</td>\n",
       "      <td>0.940372</td>\n",
       "      <td>0.880716</td>\n",
       "      <td>0.939431</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.099900</td>\n",
       "      <td>0.254255</td>\n",
       "      <td>0.940554</td>\n",
       "      <td>0.889088</td>\n",
       "      <td>0.938300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.065800</td>\n",
       "      <td>0.267429</td>\n",
       "      <td>0.942743</td>\n",
       "      <td>0.897682</td>\n",
       "      <td>0.942815</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.061200</td>\n",
       "      <td>0.282509</td>\n",
       "      <td>0.945478</td>\n",
       "      <td>0.898797</td>\n",
       "      <td>0.943881</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.036800</td>\n",
       "      <td>0.301781</td>\n",
       "      <td>0.943837</td>\n",
       "      <td>0.903816</td>\n",
       "      <td>0.944163</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.035400</td>\n",
       "      <td>0.317026</td>\n",
       "      <td>0.942560</td>\n",
       "      <td>0.902241</td>\n",
       "      <td>0.942071</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.014200</td>\n",
       "      <td>0.313259</td>\n",
       "      <td>0.946754</td>\n",
       "      <td>0.904955</td>\n",
       "      <td>0.946129</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='457' max='457' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [457/457 00:11]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /n/home01/ctheodoris/models/210602_111318_geneformer_27M_L6_emb256_SL2048_E3_B12_LR0.001_LSlinear_WU10000_Oadamw_DS12/models/ and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias', 'classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "liver\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='18690' max='18690' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [18690/18690 26:56, Epoch 10/10]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>Macro F1</th>\n",
       "      <th>Weighted F1</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>0.388500</td>\n",
       "      <td>0.385503</td>\n",
       "      <td>0.878188</td>\n",
       "      <td>0.673887</td>\n",
       "      <td>0.871348</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>2</td>\n",
       "      <td>0.315900</td>\n",
       "      <td>0.302775</td>\n",
       "      <td>0.907437</td>\n",
       "      <td>0.754182</td>\n",
       "      <td>0.903474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>3</td>\n",
       "      <td>0.242600</td>\n",
       "      <td>0.321844</td>\n",
       "      <td>0.907972</td>\n",
       "      <td>0.779504</td>\n",
       "      <td>0.905881</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>4</td>\n",
       "      <td>0.238600</td>\n",
       "      <td>0.323119</td>\n",
       "      <td>0.911539</td>\n",
       "      <td>0.790922</td>\n",
       "      <td>0.910299</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>5</td>\n",
       "      <td>0.160100</td>\n",
       "      <td>0.328203</td>\n",
       "      <td>0.915641</td>\n",
       "      <td>0.793490</td>\n",
       "      <td>0.913836</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>6</td>\n",
       "      <td>0.163100</td>\n",
       "      <td>0.348942</td>\n",
       "      <td>0.917425</td>\n",
       "      <td>0.813604</td>\n",
       "      <td>0.916911</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>7</td>\n",
       "      <td>0.124100</td>\n",
       "      <td>0.373799</td>\n",
       "      <td>0.916890</td>\n",
       "      <td>0.820355</td>\n",
       "      <td>0.916688</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>8</td>\n",
       "      <td>0.118700</td>\n",
       "      <td>0.399474</td>\n",
       "      <td>0.916890</td>\n",
       "      <td>0.818839</td>\n",
       "      <td>0.916640</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>9</td>\n",
       "      <td>0.066800</td>\n",
       "      <td>0.414363</td>\n",
       "      <td>0.917603</td>\n",
       "      <td>0.830703</td>\n",
       "      <td>0.917226</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>10</td>\n",
       "      <td>0.075800</td>\n",
       "      <td>0.413828</td>\n",
       "      <td>0.919030</td>\n",
       "      <td>0.828149</td>\n",
       "      <td>0.918506</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n",
      "<ipython-input-16-7f7bd5a45820>:54: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='936' max='468' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [468/468 00:39]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "for organ in organ_list:\n",
    "    print(organ)\n",
    "    organ_trainset = trainset_dict[organ]\n",
    "    organ_evalset = evalset_dict[organ]\n",
    "    organ_label_dict = traintargetdict_dict[organ]\n",
    "    \n",
    "    # set logging steps\n",
    "    logging_steps = round(len(organ_trainset)/geneformer_batch_size/10)\n",
    "    \n",
    "    # reload pretrained model\n",
    "    model = BertForSequenceClassification.from_pretrained(\"/path/to/pretrained_model/\", \n",
    "                                                      num_labels=len(organ_label_dict.keys()),\n",
    "                                                      output_attentions = False,\n",
    "                                                      output_hidden_states = False).to(\"cuda\")\n",
    "    \n",
    "    # define output directory path\n",
    "    current_date = datetime.datetime.now()\n",
    "    datestamp = f\"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}\"\n",
    "    output_dir = f\"/path/to/models/{datestamp}_geneformer_CellClassifier_{organ}_L{max_input_size}_B{geneformer_batch_size}_LR{max_lr}_LS{lr_schedule_fn}_WU{warmup_steps}_E{epochs}_O{optimizer}_F{freeze_layers}/\"\n",
    "    \n",
    "    # ensure not overwriting previously saved model\n",
    "    saved_model_test = os.path.join(output_dir, f\"pytorch_model.bin\")\n",
    "    if os.path.isfile(saved_model_test) == True:\n",
    "        raise Exception(\"Model already saved to this directory.\")\n",
    "\n",
    "    # make output directory\n",
    "    subprocess.call(f'mkdir {output_dir}', shell=True)\n",
    "    \n",
    "    # set training arguments\n",
    "    training_args = {\n",
    "        \"learning_rate\": max_lr,\n",
    "        \"do_train\": True,\n",
    "        \"do_eval\": True,\n",
    "        \"evaluation_strategy\": \"epoch\",\n",
    "        \"save_strategy\": \"epoch\",\n",
    "        \"logging_steps\": logging_steps,\n",
    "        \"group_by_length\": True,\n",
    "        \"length_column_name\": \"length\",\n",
    "        \"disable_tqdm\": False,\n",
    "        \"lr_scheduler_type\": lr_schedule_fn,\n",
    "        \"warmup_steps\": warmup_steps,\n",
    "        \"weight_decay\": 0.001,\n",
    "        \"per_device_train_batch_size\": geneformer_batch_size,\n",
    "        \"per_device_eval_batch_size\": geneformer_batch_size,\n",
    "        \"num_train_epochs\": epochs,\n",
    "        \"load_best_model_at_end\": True,\n",
    "        \"output_dir\": output_dir,\n",
    "    }\n",
    "    \n",
    "    training_args_init = TrainingArguments(**training_args)\n",
    "\n",
    "    # create the trainer\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=training_args_init,\n",
    "        data_collator=DataCollatorForCellClassification(),\n",
    "        train_dataset=organ_trainset,\n",
    "        eval_dataset=organ_evalset,\n",
    "        compute_metrics=compute_metrics\n",
    "    )\n",
    "    # train the cell type classifier\n",
    "    trainer.train()\n",
    "    predictions = trainer.predict(organ_evalset)\n",
    "    with open(f\"{output_dir}predictions.pickle\", \"wb\") as fp:\n",
    "        pickle.dump(predictions, fp)\n",
    "    trainer.save_metrics(\"eval\",predictions.metrics)\n",
    "    trainer.save_model(output_dir)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  },
  "vscode": {
   "interpreter": {
    "hash": "eba1599a1f7e611c14c87ccff6793920aa63510b01fc0e229d6dd014149b8829"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
